//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;
using namespace mlir::tosa;

// Builds the profile and extension compliance tables from generated metadata.
//
// Each TypeInfo local pairs a type's TypeID with its bit width (e.g. boolT is
// a 1-bit IntegerType). NOTE(review): these locals are presumably referenced
// by name from the generated include below, so they should keep these exact
// names — confirm against TosaComplianceData.h.inc. The include is expected to
// fill profileComplianceMap / extensionComplianceMap, which are what
// getProfileComplianceMap<Profile/Extension>() hand back.
TosaProfileCompliance::TosaProfileCompliance() {
  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};

// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
  // End of auto-generated metadata
}

/// Profile specialization: hands back the profile compliance table that the
/// constructor populated.
template <>
OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() {
  return this->profileComplianceMap;
}

/// Extension specialization: hands back the extension compliance table that
/// the constructor populated.
template <>
OperationExtensionComplianceMap
TosaProfileCompliance::getProfileComplianceMap() {
  return this->extensionComplianceMap;
}

// Generic populator used by ops whose every operand takes part in the
// profile/extension check: records each operand in order, then the output.
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
                                                    Value output) {
  llvm::for_each(operands, [this](Value operand) { addValue(operand); });
  addValue(output);
  return success();
}

// Concat: records only the first of the variadic inputs, then the output.
// NOTE(review): presumably sufficient because every concat input shares one
// element type — confirm against the op verifier.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
  addValue(op.getInput1().front());
  addValue(op.getOutput());
  return success();
}

// AvgPool2d: records input, input zero-point, output zero-point, the
// accumulator type, then output. The order is significant: findMatchedEntry
// compares the recorded list position-by-position with generated signatures.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
  addValue(op.getInput());
  addValue(op.getInputZp());
  addValue(op.getOutputZp());
  addType(op.getAccType());
  addValue(op.getOutput());
  return success();
}

// Shared populator for the convolution family (Conv2D, Conv3D,
// TransposeConv2D and DepthwiseConv2D in this file). Records, in order:
// input, weight, bias, input zero-point, weight zero-point, accumulator type,
// output. The order is significant: the recorded list is compared
// position-by-position against the generated compliance signatures.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
  addValue(op.getInput());
  addValue(op.getWeight());
  addValue(op.getBias());
  addValue(op.getInputZp());
  addValue(op.getWeightZp());
  addType(op.getAccType());
  addValue(op.getOutput());
  return success();
}

// The convolution family shares one operand layout, so each specialization
// forwards to populateProfileInfoConv with its concrete op type.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
  return populateProfileInfoConv<tosa::Conv2DOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
  return populateProfileInfoConv<tosa::Conv3DOp>(op);
}

template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
  return populateProfileInfoConv<tosa::TransposeConv2DOp>(op);
}

template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
  return populateProfileInfoConv<tosa::DepthwiseConv2DOp>(op);
}

// Pad: records input1, the pad constant, then output (order is significant
// for signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
  addValue(op.getInput1());
  addValue(op.getPadConst());
  addValue(op.getOutput());
  return success();
}

// Shared populator for the data-layout ops specialized below (Reshape, Slice,
// Tile, Transpose): records only input1 and output; any other operands of
// these ops are not part of the recorded type signature.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
  addValue(op.getInput1());
  addValue(op.getOutput());
  return success();
}

// Data-layout ops only constrain their data input and output, so each
// specialization forwards to populateProfileInfoDataLayout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
  return populateProfileInfoDataLayout<tosa::ReshapeOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
  return populateProfileInfoDataLayout<tosa::SliceOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
  return populateProfileInfoDataLayout<tosa::TileOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
  return populateProfileInfoDataLayout<tosa::TransposeOp>(op);
}

// Gather: records values, indices, then output (order is significant for
// signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
  addValue(op.getValues());
  addValue(op.getIndices());
  addValue(op.getOutput());
  return success();
}

// Scatter: records values_in, indices, input, then values_out (order is
// significant for signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
  addValue(op.getValuesIn());
  addValue(op.getIndices());
  addValue(op.getInput());
  addValue(op.getValuesOut());
  return success();
}

// Mul: records input1, input2, then output. Any further operands of the op
// (presumably the shift operand — confirm) are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
  addValue(op.getInput1());
  addValue(op.getInput2());
  addValue(op.getOutput());
  return success();
}

// Resize: records only the data input and the output; other operands are not
// part of the recorded type signature.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
  addValue(op.getInput());
  addValue(op.getOutput());
  return success();
}

// FFT2d: records real input, imaginary input, real output, then imaginary
// output (order is significant for signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
  addValue(op.getInputReal());
  addValue(op.getInputImag());
  addValue(op.getOutputReal());
  addValue(op.getOutputImag());
  return success();
}

// RFFT2d: records the real input, then the real and imaginary outputs (order
// is significant for signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
  addValue(op.getInputReal());
  addValue(op.getOutputReal());
  addValue(op.getOutputImag());
  return success();
}

// Select: records the on-true and on-false values, then output. Other
// operands (presumably the selection predicate — confirm) are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
  addValue(op.getOnTrue());
  addValue(op.getOnFalse());
  addValue(op.getOutput());
  return success();
}

// Rescale: records input, input zero-point, output zero-point, then output.
// Any other operands (e.g. multiplier/shift, if present) are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
  addValue(op.getInput());
  addValue(op.getInputZp());
  addValue(op.getOutputZp());
  addValue(op.getOutput());
  return success();
}

// MatMul: records A, B, A zero-point, B zero-point, then output (order is
// significant for signature matching).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
  addValue(op.getA());
  addValue(op.getB());
  addValue(op.getAZp());
  addValue(op.getBZp());
  addValue(op.getOutput());
  return success();
}

// Variable: records only the declared type of the variable (no values).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
  addType(op.getType());
  return success();
}

// VariableWrite: records only the written value (input1).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
  addValue(op.getInput1());
  return success();
}

// Records the profile/extension-relevant operand and result types of `op`.
//
// Returns success() once the matching populator ran (or for ops that carry no
// data-type constraint), and failure() for operations not listed here, i.e.
// when no type information could be recorded.
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Dispatches to the op-specific populateProfileInfo specialization, which
// records only the operands that take part in profile/extension matching.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
  if (auto concreteOp = dyn_cast<tosa::tosaOp##Op>(op))                        \
    return populateProfileInfo(concreteOp);

// Records nothing: the op has no data-type constraint to check.
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)                                     \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return success();

// Records every operand followed by the op's first result.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return populateProfileInfo(op->getOperands(), op->getResult(0));

  // Skip irrelevant operands when they are independent and not tied to any
  // specific profile/extension.
  POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Mul)
  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(Concat)
  POPULATE_PROFILE_INFO_CUSTOM(Pad)
  POPULATE_PROFILE_INFO_CUSTOM(Reshape)
  POPULATE_PROFILE_INFO_CUSTOM(Slice)
  POPULATE_PROFILE_INFO_CUSTOM(Tile)
  POPULATE_PROFILE_INFO_CUSTOM(Transpose)
  POPULATE_PROFILE_INFO_CUSTOM(Gather)
  POPULATE_PROFILE_INFO_CUSTOM(Scatter)
  POPULATE_PROFILE_INFO_CUSTOM(Resize)
  POPULATE_PROFILE_INFO_CUSTOM(Select)
  POPULATE_PROFILE_INFO_CUSTOM(Rescale)
  POPULATE_PROFILE_INFO_CUSTOM(MatMul)
  POPULATE_PROFILE_INFO_CUSTOM(Variable)
  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)

  // For most TOSA operators, all operands are profile/extension related and
  // hence are all considered in this profile-based compliance check.
  POPULATE_PROFILE_INFO_COMMON(Cast)
  POPULATE_PROFILE_INFO_COMMON(Const)
  POPULATE_PROFILE_INFO_COMMON(ArgMax)
  POPULATE_PROFILE_INFO_COMMON(Sub)
  POPULATE_PROFILE_INFO_COMMON(Maximum)
  POPULATE_PROFILE_INFO_COMMON(Minimum)
  POPULATE_PROFILE_INFO_COMMON(MaxPool2d)
  POPULATE_PROFILE_INFO_COMMON(Clamp)
  POPULATE_PROFILE_INFO_COMMON(Erf)
  POPULATE_PROFILE_INFO_COMMON(Sigmoid)
  POPULATE_PROFILE_INFO_COMMON(Tanh)
  POPULATE_PROFILE_INFO_COMMON(Add)
  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
  POPULATE_PROFILE_INFO_COMMON(BitwiseOr)
  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
  POPULATE_PROFILE_INFO_COMMON(LogicalOr)
  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
  POPULATE_PROFILE_INFO_COMMON(IntDiv)
  POPULATE_PROFILE_INFO_COMMON(Pow)
  POPULATE_PROFILE_INFO_COMMON(Table)
  POPULATE_PROFILE_INFO_COMMON(Abs)
  POPULATE_PROFILE_INFO_COMMON(Ceil)
  POPULATE_PROFILE_INFO_COMMON(Clz)
  POPULATE_PROFILE_INFO_COMMON(Sin)
  POPULATE_PROFILE_INFO_COMMON(Cos)
  POPULATE_PROFILE_INFO_COMMON(Exp)
  POPULATE_PROFILE_INFO_COMMON(Floor)
  POPULATE_PROFILE_INFO_COMMON(Log)
  POPULATE_PROFILE_INFO_COMMON(Negate)
  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
  POPULATE_PROFILE_INFO_COMMON(Rsqrt)
  POPULATE_PROFILE_INFO_COMMON(ReduceAll)
  POPULATE_PROFILE_INFO_COMMON(ReduceAny)
  POPULATE_PROFILE_INFO_COMMON(ReduceMax)
  POPULATE_PROFILE_INFO_COMMON(ReduceMin)
  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
  POPULATE_PROFILE_INFO_COMMON(ReduceSum)
  POPULATE_PROFILE_INFO_COMMON(Equal)
  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
  POPULATE_PROFILE_INFO_COMMON(Greater)
  POPULATE_PROFILE_INFO_COMMON(Reverse)
  POPULATE_PROFILE_INFO_COMMON(Identity)
  POPULATE_PROFILE_INFO_COMMON(VariableRead)

  // Type Invariant Extension, a capability extension that is independent
  // of the data type, meaning any compatible type can be used. No type
  // constraint for those operations.
  POPULATE_PROFILE_INFO_SKIP(ConstShape)
  POPULATE_PROFILE_INFO_SKIP(Yield)
  POPULATE_PROFILE_INFO_SKIP(If)
  POPULATE_PROFILE_INFO_SKIP(While)

// The helper macros are local to this function; do not leak them into the
// rest of the translation unit.
#undef POPULATE_PROFILE_INFO_CUSTOM
#undef POPULATE_PROFILE_INFO_SKIP
#undef POPULATE_PROFILE_INFO_COMMON

  // Not a known TOSA operation: no type information was recorded.
  return failure();
}

//===----------------------------------------------------------------------===//
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//

/// Looks up the compliance entry of `op` for profiles (T = Profile) or
/// extensions (T = Extension).
///
/// Fails when the operation name has no entry in the compliance map.
/// Otherwise it returns whatever findMatchedEntry produced, which may carry an
/// empty mode list when no type signature matched.
template <typename T>
FailureOr<OpComplianceInfo<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op) {
  const std::string name = op->getName().getStringRef().str();
  const auto table = getProfileComplianceMap<T>();
  const auto entry = table.find(name);
  if (entry != table.end())
    return findMatchedEntry<T>(op, entry->second);

  // No compliance metadata exists for this operation name.
  return failure();
}

/// Verifies that `op` is legal for `targetEnv` with respect to profiles
/// (T = Profile) or extensions (T = Extension).
///
/// `specRequiredModeSet` holds the requirement groups reported by the op's
/// query interface. When the op has no compliance entry, the op is accepted if
/// the target enables any member of any group. Otherwise the modes required by
/// the matched type signature are checked against the target:
/// - all of them or any of them, per the entry's condition;
/// - for extensions, that a cooperative profile is enabled;
/// - that each mode is in every specification group;
/// - that the entry's specification version is compatible with the target.
///
/// Emits an op error and returns failure() on the first violation found.
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
    Operation *op, const tosa::TargetEnv &targetEnv,
    const SmallVector<ArrayRef<T>> &specRequiredModeSet) {

  // None of profile requirement is set in the specification.
  if (specRequiredModeSet.empty())
    return success();

  const auto maybeOpDefinition = getOperatorDefinition<T>(op);
  if (failed(maybeOpDefinition)) {
    // Operators such as control-flow and shape ops do not have an operand type
    // restriction. When the profile compliance information of operation is not
    // found, confirm if the target have enabled the profile required from the
    // specification.
    size_t modeCount = 0;
    for (const auto &cands : specRequiredModeSet) {
      if (targetEnv.allowsAnyOf(cands))
        return success();
      modeCount += cands.size();
    }

    op->emitOpError() << "illegal: requires"
                      << (modeCount > 1 ? " any of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(specRequiredModeSet),
                                    ", ")
                      << "] but not enabled in target\n";

    return failure();
  }

  // Find the required profiles or extensions according to the operand type
  // combination. These are only read, so bind by reference instead of copying
  // the matched entry and its mode list on every check.
  const OpComplianceInfo<T> &opDefinition = maybeOpDefinition.value();
  const SmallVector<T> &opRequiredMode = opDefinition.mode;
  const CheckCondition condition = opDefinition.condition;

  if (opRequiredMode.empty()) {
    // No matched restriction found.
    return success();
  }

  if (condition == CheckCondition::allOf &&
      !targetEnv.allowsAllOf(opRequiredMode)) {
    op->emitOpError() << "illegal: requires"
                      << (opRequiredMode.size() > 1 ? " all of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  if (condition == CheckCondition::anyOf &&
      !targetEnv.allowsAnyOf(opRequiredMode)) {
    op->emitOpError() << "illegal: requires"
                      << (opRequiredMode.size() > 1 ? " any of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  // Each extension can contain a list of profiles that it works with, usually
  // have the same data type.
  if constexpr (std::is_same_v<T, Extension>) {
    for (const auto &mode : opRequiredMode) {
      SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
      if (!targetEnv.allowsAnyOf(coProfs)) {
        op->emitOpError() << "illegal: requires ["
                          << llvm::join(stringifyProfile<Profile>(coProfs),
                                        ", ")
                          << "] to work with but not enabled in target\n";
        return failure();
      }
    }
  }

  // Ensure the profile inference match the profile knowledge of the
  // specification.
  for (const auto &cands : specRequiredModeSet) {
    for (const auto &mode : opRequiredMode) {
      if (!llvm::is_contained(cands, mode)) {
        op->emitOpError() << "illegal: requires ["
                          << llvm::join(stringifyProfile<T>(opRequiredMode),
                                        ", ")
                          << "] but not included in the profile compliance ["
                          << llvm::join(
                                 stringifyProfile<T>(specRequiredModeSet), ", ")
                          << "]\n";
        return failure();
      }
    }
  }

  // Ensure the matched op compliance version does not exceed the target
  // specification version. A non-empty mode list only comes from a matched
  // entry, which findMatchedEntry narrows to exactly one versioned signature.
  assert(!opDefinition.operandTypeInfoSet.empty() &&
         "a matched compliance entry must carry its type signature");
  const VersionedTypeInfo &versionedTypeInfo =
      opDefinition.operandTypeInfoSet[0];
  const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
  const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
  if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
    op->emitOpError() << "illegal: the target specification version ("
                      << stringifyVersion(targetVersion)
                      << ") is not backwards compatible with the op compliance "
                         "specification version ("
                      << stringifyVersion(complianceVersion) << ")\n";
    return failure();
  }

  return success();
}

/// Profile check for `op`. Operations that do not implement
/// QueryProfileInterface expose no profile requirement and are accepted.
LogicalResult
TosaProfileCompliance::checkProfile(Operation *op,
                                    const tosa::TargetEnv &targetEnv) {
  auto profileQuery = dyn_cast<tosa::QueryProfileInterface>(op);
  if (!profileQuery)
    return success();
  return checkProfileOrExtension<Profile>(op, targetEnv,
                                          profileQuery.getProfiles());
}

/// Extension check for `op`. Operations that do not implement
/// QueryExtensionInterface expose no extension requirement and are accepted.
LogicalResult
TosaProfileCompliance::checkExtension(Operation *op,
                                      const tosa::TargetEnv &targetEnv) {
  auto extensionQuery = dyn_cast<tosa::QueryExtensionInterface>(op);
  if (!extensionQuery)
    return success();
  return checkProfileOrExtension<Extension>(op, targetEnv,
                                            extensionQuery.getExtensions());
}

/// Rejects operations whose recorded operand/result data types match none of
/// the profile or extension signatures generated for that operation.
///
/// Operations with no compliance metadata at all are accepted. When metadata
/// exists but no signature matches, an op error is emitted. The error lists
/// the observed type signature and, when one exists, the closest generated
/// signature (the one agreeing with it in the most positions).
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
  const auto maybeProfDef = getOperatorDefinition<Profile>(op);
  const auto maybeExtDef = getOperatorDefinition<Extension>(op);
  if (failed(maybeProfDef) && failed(maybeExtDef))
    return success();

  const bool hasEntry =
      (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
      (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
  if (hasEntry)
    return success();

  std::string message;
  llvm::raw_string_ostream os(message);

  // Prints a type signature as "a,b,c". Well-defined for an empty list: the
  // depot may have recorded no types, in which case findMatchedEntry reports
  // an empty mode list and we end up here with nothing to print.
  const auto printTypeInfos = [&](ArrayRef<TypeInfo> typeInfos) {
    llvm::ListSeparator separator(",");
    for (const TypeInfo &typeInfo : typeInfos)
      os << separator << stringifyTypeInfo(typeInfo);
  };

  os << "illegal: operation operand/result data types did not align with any "
        "profile or extension, got (";
  ProfileInfoDepot depot(op);
  const SmallVector<TypeInfo> current = depot.getInfo();
  printTypeInfos(current);
  os << ")";

  // Avoid polluting the error message output by suggesting only the best
  // match.
  const std::string opName = op->getName().getStringRef().str();
  int maxMatches = -1;
  SmallVector<TypeInfo> bestTypeInfo;
  const auto searchBestMatch = [&](const auto &map) {
    // Use find() rather than operator[]: the op may be listed in only one of
    // the two maps, and operator[] would insert an empty entry.
    const auto it = map.find(opName);
    if (it == map.end())
      return;
    for (const auto &complianceInfos : it->second) {
      for (const auto &versionedTypeInfos :
           complianceInfos.operandTypeInfoSet) {
        const auto &typeInfos = versionedTypeInfos.first;
        // Signatures of a different arity cannot be compared position-wise
        // (zip_equal requires equal lengths); skip them.
        if (typeInfos.size() != current.size())
          continue;
        const int matches = llvm::count_if(
            llvm::zip_equal(current, typeInfos), [&](const auto &zipType) {
              return isSameTypeInfo(std::get<0>(zipType),
                                    std::get<1>(zipType));
            });
        if (matches > maxMatches) {
          maxMatches = matches;
          bestTypeInfo.assign(typeInfos.begin(), typeInfos.end());
        }
      }
    }
  };
  searchBestMatch(getProfileComplianceMap<Profile>());
  searchBestMatch(getProfileComplianceMap<Extension>());

  if (!bestTypeInfo.empty()) {
    os << ", did you mean (";
    printTypeInfos(bestTypeInfo);
    os << ")? ";
    os << "Otherwise, please refer to the 'supported data types' for '"
       << opName << "' in the specification.";
  } else {
    // No comparable signature to suggest.
    os << ". Please refer to the 'supported data types' for '" << opName
       << "' in the specification.";
  }
  op->emitOpError(message);
  return failure();
}

// Find the profiles or extensions requirement according to the type signature
// of the profile/extension-relevant operands of `op`.
//
// Returns the first entry of `compInfo` with a versioned signature equal to
// the recorded one. The returned info is narrowed to that single matching
// signature. Returns a default-constructed info (empty mode list) when nothing
// matches or when no relevant types were recorded for `op`.
template <typename T>
OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
    Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
  assert(compInfo.size() != 0 &&
         "profile-based compliance information is empty");

  // Populate the type of profile/extension relevant operands.
  ProfileInfoDepot depot(op);
  const SmallVector<TypeInfo> present = depot.getInfo();
  if (present.empty())
    return {};

  for (const OpComplianceInfo<T> &entry : compInfo) {
    for (const auto &set : entry.operandTypeInfoSet) {
      const auto &expected = set.first;
      assert(present.size() == expected.size() &&
             "the entries for profile-based compliance do not match between "
             "the generated metadata and the type definition retrieved from "
             "the operation");
      // With assertions disabled, never index past the end of either list: a
      // signature of a different arity simply cannot match.
      if (present.size() != expected.size())
        continue;

      // Compare the type signature between the given operation and the
      // compliance metadata, position by position.
      const bool isFound = llvm::all_of(
          llvm::zip_equal(present, expected), [&](const auto &zipType) {
            return isSameTypeInfo(std::get<0>(zipType), std::get<1>(zipType));
          });
      if (!isFound)
        continue;

      SmallVector<VersionedTypeInfo> typeInfoSet{set};
      OpComplianceInfo<T> info{entry.mode, typeInfoSet, entry.condition};
      return info;
    }
  }

  return {};
}

// Debug utilities.

/// Maps each profile (T = Profile) or extension (any other T) enumerator to
/// its printable name, preserving input order.
template <typename T>
SmallVector<StringRef>
TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) {
  SmallVector<StringRef> names;
  names.reserve(profiles.size());
  for (const T &entry : profiles) {
    if constexpr (std::is_same_v<T, Profile>) {
      names.push_back(tosa::stringifyProfile(entry));
    } else {
      names.push_back(tosa::stringifyExtension(entry));
    }
  }
  return names;
}

/// Flattens a set of profile/extension groups into one list of printable
/// names, group by group, preserving order.
template <typename T>
SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
    const SmallVector<ArrayRef<T>> &profileSet) {
  SmallVector<StringRef> flattened;
  for (ArrayRef<T> group : profileSet)
    llvm::append_range(flattened, stringifyProfile<T>(group));
  return flattened;
}

/// Renders a TypeInfo for diagnostics. Integers print as "i<bitwidth>", and
/// the supported floating-point types print as their short names. Any other
/// TypeID is a programming error.
llvm::SmallString<7>
TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
  const auto id = typeInfo.typeID;
  if (id == mlir::IntegerType::getTypeID())
    return {"i" + llvm::utostr(typeInfo.bitWidth)};
  if (id == mlir::Float16Type::getTypeID())
    return {"f16"};
  if (id == mlir::Float32Type::getTypeID())
    return {"f32"};
  if (id == mlir::BFloat16Type::getTypeID())
    return {"bf16"};
  if (id == mlir::Float8E4M3FNType::getTypeID())
    return {"fp8e4m3"};
  if (id == mlir::Float8E5M2Type::getTypeID())
    return {"fp8e5m2"};
  llvm_unreachable("unknown type");
}
